Hybrid Transformer Network for Deepfake Detection¶

"https://arxiv.org/pdf/2208.05820"

In [1]:
# Reference implementation: https://github.com/sfimediafutures/Hybrid-Transformer-Network-for-Deepfake-Detection/tree/main (contains many errors)
In [2]:
!pip install einops
!pip install dlib
!pip install imutils 
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
Installing collected packages: einops
Successfully installed einops-0.8.0
Collecting dlib
  Downloading dlib-19.24.5.tar.gz (3.3 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: dlib
  Building wheel for dlib (pyproject.toml) ... done
  Created wheel for dlib: filename=dlib-19.24.5-cp310-cp310-linux_x86_64.whl size=3375717 sha256=7730dd09a2338ebee406df1b52b74d109bbd6dc24605bbf0e19f418e7d160c51
  Stored in directory: /root/.cache/pip/wheels/96/c0/80/7cda8c6ba7dc0668f4de743aaff351300f43771cb71b54b847
Successfully built dlib
Installing collected packages: dlib
Successfully installed dlib-19.24.5
Collecting imutils
  Downloading imutils-0.5.4.tar.gz (17 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: imutils
  Building wheel for imutils (setup.py) ... done
  Created wheel for imutils: filename=imutils-0.5.4-py3-none-any.whl size=25834 sha256=368ccb6bcb229d3a75a054e78e80e514d8203001f6db05c101d10efb9f866d37
  Stored in directory: /root/.cache/pip/wheels/85/cf/3a/e265e975a1e7c7e54eb3692d6aa4e2e7d6a3945d29da46f2d7
Successfully built imutils
Installing collected packages: imutils
Successfully installed imutils-0.5.4


In [3]:
from __future__ import print_function
import glob
from itertools import chain
import os
import cv2
import random
import zipfile
import os.path as osp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from functools import reduce
import torch.nn as nn
from einops import rearrange, repeat
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.transforms import ToPILImage
# from linformer import Linformer
from PIL import Image
from imgaug import augmenters as iaa
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
# from vit_pytorch.efficient import ViT
# from model import BiSeNet
import torchvision.transforms as transforms
from skimage import io, img_as_float
import timm
import urllib.request as urlreq
In [63]:
LBFmodel_url = "https://github.com/kurnianggoro/GSOC2017/raw/master/data/lbfmodel.yaml"

# local filename under which the facial landmark detection model is saved
LBFmodel = "lbfmodel.yaml"

# check if file is in working directory
if (LBFmodel in os.listdir(os.curdir)):
    print("File exists")
else:
    # download the model from the URL and save it locally as lbfmodel.yaml (< 54 MB)
    urlreq.urlretrieve(LBFmodel_url, LBFmodel)
    print("File downloaded")
File exists
In [64]:
def get_base_config():
    """Base ViT config ViT"""
    return dict(
      dim=768,
      ff_dim=3072,
      num_heads=12,
      num_layers=12,
      attention_dropout_rate=0.0,
      dropout_rate=0.1,
      representation_size=768,
      classifier='token'
    )

def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = get_base_config()
    config.update(dict(patches=(16, 16)))
    return config

def get_b32_config():
    """Returns the ViT-B/32 configuration."""
    config = get_b16_config()
    config.update(dict(patches=(32, 32)))
    return config

def get_l16_config():
    """Returns the ViT-L/16 configuration."""
    config = get_base_config()
    config.update(dict(
        patches=(16, 16),
        dim=1024,
        ff_dim=4096,
        num_heads=16,
        num_layers=24,
        attention_dropout_rate=0.0,
        dropout_rate=0.1,
        representation_size=1024
    ))
    return config

def get_l32_config():
    """Returns the ViT-L/32 configuration."""
    config = get_l16_config()
    config.update(dict(patches=(32, 32)))
    return config

def drop_head_variant(config):
    config.update(dict(representation_size=None))
    return config


PRETRAINED_MODELS = {
    'B_16': {
      'config': get_b16_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth"
    },
    'B_32': {
      'config': get_b32_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth"
    },
    'L_16': {
      'config': get_l16_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': None
    },
    'L_32': {
      'config': get_l32_config(),
      'num_classes': 21843,
      'image_size': (224, 224),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32.pth"
    },
    'B_16_imagenet1k': {
      'config': drop_head_variant(get_b16_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16_imagenet1k.pth"
    },
    'B_32_imagenet1k': {
      'config': drop_head_variant(get_b32_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32_imagenet1k.pth"
    },
    'L_16_imagenet1k': {
      'config': drop_head_variant(get_l16_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_16_imagenet1k.pth"
    },
    'L_32_imagenet1k': {
      'config': drop_head_variant(get_l32_config()),
      'num_classes': 1000,
      'image_size': (384, 384),
      'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32_imagenet1k.pth"
    },
}
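As a quick illustration (not part of the original notebook), the table above can be queried to see which configuration a checkpoint name resolves to when it is later passed to ViT(name=...):

# illustrative lookup into PRETRAINED_MODELS
cfg = PRETRAINED_MODELS['B_16_imagenet1k']
print(cfg['image_size'], cfg['num_classes'])            # (384, 384) 1000
print(cfg['config']['patches'], cfg['config']['dim'])   # (16, 16) 768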
In [65]:
txt_to_csv = False
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIR_PATH = "/kaggle/input/deepfake/phase1"
TRAIN_DIR = "/kaggle/input/deepfake/phase1/trainset"
TEST_DIR = "/kaggle/input/deepfake/phase1/valset"
OUTPUT_DIR = "/kaggle/working/"
class CFG : 
    seed = 42
    n_fold = 5
    target_col = 'target'
    train=True
    inference=False
    pseudo_labeling = True
    num_classes = 2  # binary classification (real vs. fake)
    trn_fold=[0, 1]
    debug=False
    apex=False
    print_freq=20  # print scores every this many batches
    num_workers=4
#     model_name="eva02_large_patch14_448.mim_m38m_ft_in22k_in1k"
    model_name=  "efficientnet_b3"
    size=448
    scheduler='CosineAnnealingWarmRestarts' 
    epochs=2
    lr=1e-4
    min_lr=1e-6
    T_0=10 # CosineAnnealingWarmRestarts
    batch_size=20
    weight_decay=1e-6
    gradient_accumulation_steps=1
    max_grad_norm=1000
    paper_name = "Hybrid-Transformer-Network-for-Deepfake-Detection"
    
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(CFG.seed)
In [66]:
train = pd.read_csv(f"{DIR_PATH +'/trainset_label.txt'}")
test = pd.read_csv(f"{DIR_PATH +'/valset_label.txt'}")
In [67]:
train['img_name'] = TRAIN_DIR+'/'+train['img_name']
In [68]:
train_list = train['img_name'].values
labels = train[CFG.target_col].values
In [69]:
random_idx = np.random.randint(1, len(train_list), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))

# show a random sample of training images with their labels
for ax, idx in zip(axes.ravel(), random_idx):
    img = Image.open(train_list[idx])
    ax.set_title(labels[idx])
    ax.imshow(img)
In [70]:
import albumentations as A
import numpy as np
import torchvision.transforms as T
In [71]:
class AlbumentationsTransform:
    def __init__(self):
        self.aug = A.Compose([
            A.Resize(height=224, width=224, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.OneOf([
                A.RandomScale(scale_limit=0.5, p=1.0),
                A.Rotate(limit=20, p=1.0),
                A.ShiftScaleRotate(scale_limit=0.2, rotate_limit=20, p=1.0, border_mode=0),
                A.CoarseDropout(max_holes=1, max_height=50, max_width=50, min_holes=1, min_height=20, min_width=20, fill_value=0, p=1.0),
            ], p=1.0),
            A.Resize(height=224, width=224, p=1.0)  # Ensure resize after all other transformations
        ])

    def __call__(self, img):
        img = np.array(img)
        img = self.aug(image=img)['image']
        img = T.ToTensor()(img)
        img = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(img)
        return img
    
       
transforms_imgaug = AlbumentationsTransform()

train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
#         torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
#         torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
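A minimal sanity check for the Albumentations pipeline defined above (illustrative only, using a random dummy image): the output should be a normalized 3x224x224 float tensor.

# apply transforms_imgaug to a dummy PIL image and inspect the result
dummy = Image.fromarray(np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8))
out = transforms_imgaug(dummy)
print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32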
In [72]:
class DeepFakeSet(Dataset):
    def __init__(self, data_set, transform=None):
        self.data_set = data_set
        self.file_list = data_set['img_name'].values
        self.transform = transform
        self.to_img = ToPILImage()
        self.detector = dlib.get_frontal_face_detector()
        self.predictor = dlib.shape_predictor("/kaggle/input/81_face_landmarks/other/default/1/shape_predictor_81_face_landmarks.dat")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path).convert('RGB')
        img_np = np.array(img)
#         print(f'Original image shape: {img_np.shape}')
        
        random_num = torch.randint(1, 10, (1,))
        augmented_face = augment_face(self.detector, self.predictor, img_path, random_num)
        
        img_transformed = self.transform(self.to_img(augmented_face)) if self.transform else self.to_img(augmented_face)
#         print(f'Transformed image shape: {img_transformed.shape}')
        label = self.data_set[self.data_set['img_name'] == img_path][CFG.target_col].values[0]
        
        return img_transformed, label
In [73]:
train
Out[73]:
img_name target
0 /kaggle/input/deepfake/phase1/trainset/3381ccb... 1
1 /kaggle/input/deepfake/phase1/trainset/63fee8a... 0
2 /kaggle/input/deepfake/phase1/trainset/7eb4553... 0
3 /kaggle/input/deepfake/phase1/trainset/9200859... 1
4 /kaggle/input/deepfake/phase1/trainset/f632068... 1
... ... ...
524424 /kaggle/input/deepfake/phase1/trainset/1af9be6... 1
524425 /kaggle/input/deepfake/phase1/trainset/fa3c2a3... 1
524426 /kaggle/input/deepfake/phase1/trainset/d639604... 1
524427 /kaggle/input/deepfake/phase1/trainset/c477803... 1
524428 /kaggle/input/deepfake/phase1/trainset/03e223e... 1

524429 rows × 2 columns

In [74]:
train_list, valid_list = train_test_split(train, 
                                          test_size=0.2,
#                                           stratify=labels,
                                          random_state=CFG.seed)
In [86]:
class DeepFakeSet(Dataset):
    def __init__(self, file_list, transform=None):

        self.file_list = file_list
        self.transform = transform
        self.to_img = ToPILImage()
        self.detector = dlib.get_frontal_face_detector()
        self.predictor = dlib.shape_predictor("/kaggle/input/81_face_landmarks/other/default/1/shape_predictor_81_face_landmarks.dat")

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list.iloc[idx]['img_name']
        img = Image.open(img_path)
        random_num = torch.randint(1, 10, (1,))
        augmented_face = augment_face(self.detector, self.predictor, img_path, random_num)
        img_transformed = self.transform(self.to_img(augmented_face))

        label = self.file_list.iloc[idx]['target']
#         label = 1 if label == "real" else 0

        return img_transformed, label
In [87]:
import dlib
from imutils import face_utils
In [88]:
def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)


def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)


class MultiHeadedSelfAttention(nn.Module):
    """Multi-Headed Dot Product Attention"""
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None # for visualization

    def forward(self, x, mask):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h


class PositionWiseFeedForward(nn.Module):
    """FeedForward Neural Networks for each position"""
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)

    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        return self.fc2(F.gelu(self.fc1(x)))


class Block(nn.Module):
    """Transformer Block"""
    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        h = self.drop(self.proj(self.attn(self.norm1(x), mask)))
        x = x + h
        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x


class Transformer(nn.Module):
    """Transformer with Self-Attentive Blocks"""
    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x
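A small shape check (illustrative sketch, not from the reference code): split_last and merge_last are inverses along the head dimension, and a single Block is shape-preserving.

x = torch.randn(2, 5, 768)                              # (batch, seq_len, dim)
print(split_last(x, (12, -1)).shape)                    # torch.Size([2, 5, 12, 64]) -- 12 heads of width 64
print(merge_last(split_last(x, (12, -1)), 2).shape)     # torch.Size([2, 5, 768])
blk = Block(dim=768, num_heads=12, ff_dim=3072, dropout=0.1)
print(blk(x, mask=None).shape)                          # torch.Size([2, 5, 768])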
In [89]:
"""utils.py - Helper functions
"""

import numpy as np
import torch
from torch.utils import model_zoo

# from configs import PRETRAINED_MODELS


def load_pretrained_weights(
    model, 
    model_name=None, 
    weights_path=None, 
    load_first_conv=False, 
    load_fc=True, 
    load_repr_layer=False,
    resize_positional_embedding=False,
    verbose=True,
    strict=False,
):
    """Loads pretrained weights from weights path or download using url.
    Args:
        model (Module): Full model (a nn.Module)
        model_name (str): Model name (e.g. B_16)
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_first_conv (bool): Whether to load patch embedding.
        load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
        resize_positional_embedding=False,
        verbose (bool): Whether to print on completion
    """
    assert bool(model_name) ^ bool(weights_path), 'Expected exactly one of model_name or weights_path'
    
    # Load or download weights
    if weights_path is None:
        url = PRETRAINED_MODELS[model_name]['url']
        if url:
            state_dict = model_zoo.load_url(url)
        else:
            raise ValueError(f'Pretrained model for {model_name} has not yet been released')
    else:
        state_dict = torch.load(weights_path)

    # Modifications to load partial state dict
    expected_missing_keys = []
    if not load_first_conv and 'patch_embedding.weight' in state_dict:
        expected_missing_keys += ['patch_embedding.weight', 'patch_embedding.bias']
    if not load_fc and 'fc.weight' in state_dict:
        expected_missing_keys += ['fc.weight', 'fc.bias']
    if not load_repr_layer and 'pre_logits.weight' in state_dict:
        expected_missing_keys += ['pre_logits.weight', 'pre_logits.bias']
    for key in expected_missing_keys:
        state_dict.pop(key)

    # Change size of positional embeddings
    if resize_positional_embedding: 
        posemb = state_dict['positional_embedding.pos_embedding']
        posemb_new = model.state_dict()['positional_embedding.pos_embedding']
        state_dict['positional_embedding.pos_embedding'] = \
            resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 
                has_class_token=hasattr(model, 'class_token'))
        maybe_print('Resized positional embeddings from {} to {}'.format(
                    posemb.shape, posemb_new.shape), verbose)

    # Load state dict
    ret = model.load_state_dict(state_dict, strict=False)
    if strict:
        assert set(ret.missing_keys) == set(expected_missing_keys), \
            'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
        assert not ret.unexpected_keys, \
            'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
        maybe_print('Loaded pretrained weights.', verbose)
    else:
        maybe_print('Missing keys when loading pretrained weights: {}'.format(ret.missing_keys), verbose)
        maybe_print('Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys), verbose)
    return ret


def maybe_print(s: str, flag: bool):
    if flag:
        print(s)


def as_tuple(x):
    return x if isinstance(x, tuple) else (x, x)


def resize_positional_embedding_(posemb, posemb_new, has_class_token=True):
    """Rescale the grid of position embeddings in a sensible manner"""
    from scipy.ndimage import zoom

    # Deal with class token
    ntok_new = posemb_new.shape[1]
    if has_class_token:  # this means classifier == 'token'
        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
        ntok_new -= 1
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

    # Get old and new grid sizes
    gs_old = int(np.sqrt(len(posemb_grid)))
    gs_new = int(np.sqrt(ntok_new))
    posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

    # Rescale grid
    zoom_factor = (gs_new / gs_old, gs_new / gs_old, 1)
    posemb_grid = zoom(posemb_grid, zoom_factor, order=1)
    posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
    posemb_grid = torch.from_numpy(posemb_grid)

    # Deal with class token and return
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb
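To see resize_positional_embedding_ in isolation, here is a minimal sketch with random tensors (the 18x18 target grid matches image_size=300 used later, since 300 // 16 = 18):

posemb     = torch.randn(1, 1 + 14 * 14, 768)   # old table: class token + 14x14 patch grid (224/16)
posemb_new = torch.zeros(1, 1 + 18 * 18, 768)   # new table: class token + 18x18 patch grid
resized = resize_positional_embedding_(posemb, posemb_new, has_class_token=True)
print(resized.shape)  # torch.Size([1, 325, 768])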
In [91]:
train_data = DeepFakeSet(train_list, transform=transforms_imgaug)
valid_data = DeepFakeSet(valid_list, transform=val_transforms)
# test_data = DeepFakeSet(test_list, transform=test_transforms)

batch_size = 12

train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
# # test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

print(len(train_data), len(train_loader))
print(len(valid_data), len(valid_loader))
419543 34962
104886 8741
In [95]:
to_img = ToPILImage()

random_idx = np.random.randint(1, len(train_data), size=9)
fig, axes = plt.subplots(3, 3, figsize=(16, 12))

# visualize a random sample of augmented training images with their labels
for ax, idx in zip(axes.ravel(), random_idx):
    img, label = train_data[idx]
    ax.set_title(label)
    ax.imshow(to_img(img))
In [96]:
"""model.py - Model and module class for ViT.
   They are built to mirror those in the official Jax implementation.
"""

from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F


# from transformer import Transformer
# from utils_ViT import load_pretrained_weights, as_tuple
# from configs import PRETRAINED_MODELS

class Transformer(nn.Module):
    """Transformer with Self-Attentive Blocks"""
    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x
    
In [97]:
class PositionalEmbedding1D(nn.Module):
    """Adds (optionally learned) positional embeddings to the inputs."""

    def __init__(self, seq_len, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim))
    
    def forward(self, x):
        """Input has shape `(batch_size, seq_len, emb_dim)`"""
        return x + self.pos_embedding
In [98]:
class SequenceEmbedding(nn.Module):
    """Adds (optionally learned) sequence embeddings to the inputs."""

    def __init__(self, seq_len, dim):
        super().__init__()
        # note: the table is hard-coded to 162 tokens (one backbone's sequence length); seq_len is ignored
        self.seq_embedding = nn.Parameter(torch.zeros(1, 162, dim))
    
    def forward(self, x):
        """Input has shape `(batch_size, seq_len, emb_dim)`"""
        return x + self.seq_embedding
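Both embedding modules simply add a learned table that broadcasts over the batch dimension; a quick illustrative check:

pos = PositionalEmbedding1D(seq_len=325, dim=768)
tokens = torch.randn(2, 325, 768)
print(pos(tokens).shape)  # torch.Size([2, 325, 768])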
In [99]:
class ViT(nn.Module):
    """
    Args:
        name (str): Model name, e.g. 'B_16'
        pretrained (bool): Load pretrained weights
        in_channels (int): Number of channels in input data
        num_classes (int): Number of classes, default 1000

    References:
        [1] https://openreview.net/forum?id=YicbFdNTTy
    """

    def __init__(
        self, 
        name: Optional[str] = None, 
        pretrained: bool = False, 
        patches: int = 16,
        dim: int = 768,
        ff_dim: int = 3072,
        num_heads: int = 12,
        num_layers: int = 12,
        attention_dropout_rate: float = 0.0,
        dropout_rate: float = 0.1,
        representation_size: Optional[int] = None,
        load_repr_layer: bool = False,
        classifier: str = 'token',
        positional_embedding: str = '1d',
        in_channels: int = 3, 
        image_size: Optional[int] = None,
        num_classes: Optional[int] = None,
    ):
        super().__init__()

        # Configuration
        if name is None:
            check_msg = 'must specify name of pretrained model'
            assert not pretrained, check_msg
            if num_classes is None:
                num_classes = 1000
            if image_size is None:
                image_size = 384
        else:  # load pretrained model
            assert name in PRETRAINED_MODELS.keys(), \
                'name should be in: ' + ', '.join(PRETRAINED_MODELS.keys())
            config = PRETRAINED_MODELS[name]['config']
            patches = config['patches']
            dim = config['dim']
            ff_dim = config['ff_dim']
            num_heads = config['num_heads']
            num_layers = config['num_layers']
            attention_dropout_rate = config['attention_dropout_rate']
            dropout_rate = config['dropout_rate']
            representation_size = config['representation_size']
            classifier = config['classifier']
            if image_size is None:
                image_size = PRETRAINED_MODELS[name]['image_size']
            if num_classes is None:
                num_classes = PRETRAINED_MODELS[name]['num_classes']
        self.image_size = image_size                

        # Image and patch sizes
        h, w = as_tuple(image_size)  # image sizes
        fh, fw = as_tuple(patches)  # patch sizes
        gh, gw = h // fh, w // fw  # number of patches
        seq_len = gh * gw

        def get_feature_extractor():
            feature_extractor = timm.create_model('xception', features_only=True, pretrained=True)
            feature_extractor.eval()
            return feature_extractor
#         self.parse_net = ViT.load_face_parser()
        self.xception_feature_extractor = get_feature_extractor()
    
        def get_resnet():
            # note: despite the name, this loads an EfficientNet-B4 backbone (448-channel final stage)
            feature_extractor = timm.create_model('efficientnet_b4', features_only=True, pretrained=True)
#             feature_extractor.eval()
            return feature_extractor
#         self.parse_net = ViT.load_face_parser()
        self.resnet_feature_extractor = get_resnet()
    
        # Patch embedding
#         self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=(14,28), stride=(14, 28))
        
#         self.sequence_embedding = SequenceEmbedding(162, dim)
        self.linear_1 = torch.nn.Linear(49, 162)
#         self.linear_2 = torch.nn.Linear(2048, 768)
        self.proj = nn.Conv2d(2048, 768, 1)
    
#         self.linear_2 = torch.nn.Linear(100, 162)
        self.proj_2 = nn.Conv2d(448, 768, 1)
                
#       Class token
        if classifier == 'token':
            self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
            seq_len += 1
        
#       Positional embedding
        if positional_embedding.lower() == '1d':
            self.positional_embedding = PositionalEmbedding1D(seq_len, dim)
        else:
            raise NotImplementedError()
                    
        #Transformer
        self.transformer = Transformer(num_layers=num_layers, dim=dim, num_heads=num_heads, 
                                       ff_dim=ff_dim, dropout=dropout_rate)
        
        # Representation layer
        if representation_size and load_repr_layer:
            self.pre_logits = nn.Linear(dim, representation_size)
            pre_logits_size = representation_size
        else:
            pre_logits_size = dim

        # Classifier head
        self.norm = nn.LayerNorm(pre_logits_size, eps=1e-6)
        self.fc = nn.Linear(pre_logits_size, num_classes)

        # Initialize weights
        self.init_weights()
        
        # Load pretrained model
        if pretrained:
            pretrained_num_channels = 3
            pretrained_num_classes = PRETRAINED_MODELS[name]['num_classes']
            pretrained_image_size = PRETRAINED_MODELS[name]['image_size']
            load_pretrained_weights(
                self, name, 
                load_fc=(num_classes == pretrained_num_classes),
                load_repr_layer=load_repr_layer,
                resize_positional_embedding=(image_size != pretrained_image_size),
            )
            
        
    @torch.no_grad()
    def init_weights(self):
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)  # _trunc_normal(m.weight, std=0.02)  # from .initialization import _trunc_normal
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)  # nn.init.constant(m.bias, 0)
        self.apply(_init)
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02)  # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02)
        nn.init.constant_(self.class_token, 0)

    def forward(self, x):
        """Breaks image into patches, applies transformer, applies MLP head.
        print(data.shape)
         Args:
            x (tensor): `b,c,fh,fw`
        """
        b, c, fh, fw = x.shape
#         x = DDFA.generate_uv_tex(x, self.xception_feature_extractor, self.sequence_embedding, self.linear_1, self.proj)

        x1 = self.xception_feature_extractor(x)
        x1 = x1[-1]
        x1 = self.proj(x1).flatten(2)
        x1 = self.linear_1(x1)
        x1 = x1.transpose(2,1)
#         x1 = self.sequence_embedding(x1)
#         print(x1.shape)
        
        x2 = self.resnet_feature_extractor(x)
        x2 = x2[-1]
        x2 = self.proj_2(x2).flatten(2)
        x2 = self.linear_1(x2)
        x2 = x2.transpose(2,1)
#         x2 = self.sequence_embedding(x2)
#         print(x2.shape)
        
        x = torch.cat((x1, x2), dim=1)
#         print(x.shape)
#         x = self.patch_embedding(x)  # b,d,gh,gw
#         print(x.shape)
#         x = x.flatten(2).transpose(1, 2)  # b,gh*gw,d
#         print(x.shape)
        if hasattr(self, 'class_token'):
            x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)  # b,gh*gw+1,d
        if hasattr(self, 'positional_embedding'): 
            x = self.positional_embedding(x)  # b,gh*gw+1,d 

        x = self.transformer(x)  # b,gh*gw+1,d
        if hasattr(self, 'pre_logits'):
            x = self.pre_logits(x)
            x = torch.tanh(x)
        if hasattr(self, 'fc'):
            x = self.norm(x)[:, 0]  # b,d
            x = self.fc(x)  # b,num_classes
        return x
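A minimal sketch of the token-fusion step in forward, using random stand-in feature maps so no pretrained backbones need to be downloaded. With 224x224 inputs (as produced by the transforms above) both backbones end in a 7x7 map (49 positions); linear_1 expands each to 162 tokens, so the fused sequence has 2*162 + 1 (class token) = 325 positions, which equals 1 + (300 // 16)**2 and therefore matches the positional embedding created with image_size=300.

b = 2
feat_x = torch.randn(b, 2048, 7, 7)    # stand-in for the last Xception feature map
feat_e = torch.randn(b, 448, 7, 7)     # stand-in for the last EfficientNet-B4 feature map
proj, proj_2 = nn.Conv2d(2048, 768, 1), nn.Conv2d(448, 768, 1)
linear_1 = nn.Linear(49, 162)
t1 = linear_1(proj(feat_x).flatten(2)).transpose(2, 1)    # (b, 162, 768)
t2 = linear_1(proj_2(feat_e).flatten(2)).transpose(2, 1)  # (b, 162, 768)
print(torch.cat((t1, t2), dim=1).shape)                   # torch.Size([2, 324, 768])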
In [100]:
def face_remap(shape):
    remapped_image = cv2.convexHull(np.array(shape, dtype=np.float32))
    return remapped_image

def augment_face(detector, predictor, img, random_num):
    # Check if img is a file path or a NumPy array
    if isinstance(img, str):
        frame = cv2.imread(img)
    else:
        frame = img.copy()

    # Ensure the frame is in the correct format
    if frame is None:
        raise ValueError("Image could not be loaded.")

    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    out_face = np.zeros_like(frame, dtype=np.uint8)
    feature_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

    # Detect faces; if none are found, return the unmodified frame (converted to RGB for consistency)
    faces = detector(gray, 1)
    if not faces:
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    for face in faces:
        landmarks = predictor(gray, face)
        shape = face_utils.shape_to_np(landmarks)
        shape = np.array(shape, dtype=np.float32)  # Ensure shape is in float32

        # Process based on random_num value
        if random_num == 1:
            for indices in [shape[1:12], shape[3:14], shape[6:17]]:
                remapped_shape = face_remap(indices)
                cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        elif random_num == 2:
            for indices in [shape[37:42], shape[43:48], shape[49:68]]:
                remapped_shape = face_remap(indices)
                cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        elif random_num == 3:
            for indices in [shape[28:36], shape[37:42], shape[43:48], shape[49:68]]:
                remapped_shape = face_remap(indices)
                cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        elif random_num == 4:
            remapped_shape = face_remap(shape[1:12])
            cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        elif random_num == 5:
            remapped_shape = face_remap(shape[6:17])
            cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        elif random_num == 6:
            for indices in [shape[28:36], shape[49:68]]:
                remapped_shape = face_remap(indices)
                cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        elif random_num == 7:
            remapped_shape = face_remap(shape[49:68])
            cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        elif random_num == 8:
            for indices in [shape[37:42], shape[43:48], shape[28:46]]:
                remapped_shape = face_remap(indices)
                cv2.fillConvexPoly(feature_mask, remapped_shape.astype(np.int32), 1)

        # Apply the feature mask: copy only the selected landmark regions
        out_face[feature_mask == 1] = frame[feature_mask == 1]

        # Subtract (black out) the selected regions from the original frame
        frame = frame - out_face

    # convert once, after all faces are processed, so every return path is RGB
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
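The core of augment_face is a copy-then-subtract trick: the selected landmark regions are copied into out_face and then subtracted from the frame, blacking them out. A minimal sketch with a synthetic triangular mask (no dlib model needed):

frame = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
out_face = np.zeros_like(frame)
feature_mask = np.zeros((100, 100), dtype=np.uint8)
cv2.fillConvexPoly(feature_mask, np.array([[20, 20], [80, 20], [50, 80]], dtype=np.int32), 1)
out_face[feature_mask == 1] = frame[feature_mask == 1]   # copy the masked region
occluded = frame - out_face                              # masked region becomes zero (black)
print((occluded[feature_mask == 1] == 0).all())          # True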
In [102]:
%%capture
model = ViT('B_16_imagenet1k', pretrained=True,
    image_size = 300,
    num_classes = 2).to(device)
In [103]:
epochs = 1
lr = 3e-3
# gamma = 0.7

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.SGD(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
In [ ]:
# NO SegEmbed f2f only eval on with_new_train_valid_strategy
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    epoch_accuracy = 0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label)
#         print("HERE")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in tqdm(valid_loader):
            data = data.to(device)
            label = label.to(device)
#             print(data.shape)
            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)
    
    if epoch in [2, 3, 4, 5, 6, 7, 8, 9, 10]:
        PATH = OUTPUT_DIR+f'{CFG.paper_name}_fold{epoch}_save.pth'
        torch.save(model.state_dict(), PATH)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
  0%|          | 0/34962 [00:00<?, ?it/s]